Investment firms, hedge funds and even individuals have been using financial models to better understand market behavior and make profitable investments and trades. A wealth of information is available in the form of historical stock prices and company performance data, suitable for machine learning algorithms to process.
For this project, our task is to build a stock price predictor that takes daily trading data over a certain date range as input, and outputs projected estimates for given query dates.
# import libraries
import pandas as pd
from pandas_datareader import data as pdr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler
import datetime
import time
import math
import torch
import torch.nn as nn
# Date range for the historical download.
start_date = datetime.datetime(2015, 1, 1)
end_date = datetime.datetime(2020, 12, 31)

#Tickers : GM AMZN BTC-USD
df = pdr.get_data_yahoo('GM', start=start_date, end=end_date).reset_index()

# Keep only the adjusted close. .copy() makes an independent DataFrame so
# the in-place scaling below does not write through a view of `df`
# (this was the source of the SettingWithCopyWarning).
price = df[['Adj Close']].copy()

# Scale prices into [-1, 1]; the fitted scaler is kept so predictions can
# later be inverse-transformed back to dollar values.
scaler = MinMaxScaler(feature_range=(-1, 1))
price['Adj Close'] = scaler.fit_transform(price['Adj Close'].values.reshape(-1, 1))
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:10: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame. Try using .loc[row_indexer,col_indexer] = value instead See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
def split_data(stock, lookback):
    """
    Split the daily data into a training set and a test set.

    Builds every sliding window of ``lookback`` consecutive rows of
    ``stock``; within each window the first ``lookback - 1`` steps are
    the model input and the last step is the target. The final ~20% of
    the windows (chronologically) form the test set.

    Parameters
    ----------
    stock : DataFrame
        Daily data, one feature per column.
    lookback : int
        Lookback window length (so the input length is ``lookback - 1``).

    Returns
    -------
    x_train, y_train, x_test, y_test : arrays
        Inputs of shape (n, lookback - 1, features) and targets of
        shape (n, features) for the training and test splits.
    """
    data_raw = stock.to_numpy()  # convert to numpy array, (days, features)

    # All contiguous windows of `lookback` consecutive days.
    windows = np.array([data_raw[i:i + lookback]
                        for i in range(len(data_raw) - lookback)])

    # Chronological split: the last ~20% of windows are held out.
    test_set_size = int(np.round(0.2 * windows.shape[0]))
    train_set_size = windows.shape[0] - test_set_size

    x_train = windows[:train_set_size, :-1, :]
    y_train = windows[:train_set_size, -1, :]
    x_test = windows[train_set_size:, :-1, :]
    y_test = windows[train_set_size:, -1, :]
    return [x_train, y_train, x_test, y_test]
# Length of the sliding window fed to the recurrent models.
lookback = 10

x_train, y_train, x_test, y_test = split_data(price, lookback)

# Wrap the numpy splits as float torch tensors.
x_train = torch.from_numpy(x_train).type(torch.Tensor)
x_test = torch.from_numpy(x_test).type(torch.Tensor)

# Separate target copies for the LSTM and the GRU experiments.
y_train_lstm = torch.from_numpy(y_train).type(torch.Tensor)
y_test_lstm = torch.from_numpy(y_test).type(torch.Tensor)
y_train_gru = torch.from_numpy(y_train).type(torch.Tensor)
y_test_gru = torch.from_numpy(y_test).type(torch.Tensor)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Model hyper-parameters (alternatives tried during tuning in comments).
input_dim = 1     # one feature: the scaled adjusted close
output_dim = 1    # one predicted value per sequence
hidden_dim = 64   # tried: 32, 64, 128
num_layers = 2    # tried: 2, 3
num_epochs = 500  # tried: 100, 250, 500, 1000
cuda
class LSTM(nn.Module):
    """Many-to-one LSTM regressor with a linear head.

    Takes a batch of sequences shaped (batch, seq_len, input_dim) and
    returns one prediction per sequence, shaped (batch, output_dim),
    computed from the last time step's hidden state.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # batch_first=True -> inputs/outputs are (batch, seq, feature).
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Fresh zero hidden/cell state per batch, allocated on the same
        # device as the input. The original called requires_grad_() and
        # then immediately detach()-ed — a no-op pair — and depended on
        # a global `device`; both removed.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        # Only the last time step feeds the regression head.
        return self.fc(out[:, -1, :])
# Instantiate the LSTM model, its loss and its optimiser.
model = LSTM(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_layers=num_layers)
criterion = torch.nn.MSELoss(reduction='mean')
#optimiser = torch.optim.SGD(model.parameters(), lr=0.01)
optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

# Move model, loss and training data onto the selected device.
model.to(device)
criterion.to(device)
x_train = x_train.to(device)
y_train_lstm = y_train_lstm.to(device)

hist = np.zeros(num_epochs)  # per-epoch training loss
start_time = time.time()
lstm = []                    # collects scores/timing for the summary

# Full-batch training: one forward/backward pass over x_train per epoch.
for epoch in range(num_epochs):
    y_train_pred = model(x_train).to(device)
    loss = criterion(y_train_pred, y_train_lstm)
    print("Epoch ", epoch, "MSE: ", loss.item())
    hist[epoch] = loss.item()

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

training_time = time.time() - start_time
print("Training time: {}".format(training_time))
Epoch 0 MSE: 0.10327287763357162 Epoch 1 MSE: 0.10609275102615356 Epoch 2 MSE: 0.084849514067173 Epoch 3 MSE: 0.07953985780477524 Epoch 4 MSE: 0.06514491885900497 Epoch 5 MSE: 0.029029162600636482 Epoch 6 MSE: 0.018505150452256203 Epoch 7 MSE: 0.12133926153182983 Epoch 8 MSE: 0.022022387012839317 Epoch 9 MSE: 0.034502267837524414 Epoch 10 MSE: 0.04111925885081291 Epoch 11 MSE: 0.045831792056560516 Epoch 12 MSE: 0.049516815692186356 Epoch 13 MSE: 0.051705021411180496 Epoch 14 MSE: 0.0520317442715168 Epoch 15 MSE: 0.05007021874189377 Epoch 16 MSE: 0.04515652358531952 Epoch 17 MSE: 0.036432500928640366 Epoch 18 MSE: 0.023690905421972275 Epoch 19 MSE: 0.011560875922441483 Epoch 20 MSE: 0.01673886924982071 Epoch 21 MSE: 0.0173213891685009 Epoch 22 MSE: 0.009943884797394276 Epoch 23 MSE: 0.017121853306889534 Epoch 24 MSE: 0.017338145524263382 Epoch 25 MSE: 0.009654054418206215 Epoch 26 MSE: 0.0052611446008086205 Epoch 27 MSE: 0.006763740908354521 Epoch 28 MSE: 0.010432498529553413 Epoch 29 MSE: 0.012153762392699718 Epoch 30 MSE: 0.010830311104655266 Epoch 31 MSE: 0.008078470826148987 Epoch 32 MSE: 0.005913082044571638 Epoch 33 MSE: 0.005378433037549257 Epoch 34 MSE: 0.0062453122809529305 Epoch 35 MSE: 0.0071733491495251656 Epoch 36 MSE: 0.006912760902196169 Epoch 37 MSE: 0.005903179757297039 Epoch 38 MSE: 0.0054878066293895245 Epoch 39 MSE: 0.00565411988645792 Epoch 40 MSE: 0.005462123546749353 Epoch 41 MSE: 0.004951308947056532 Epoch 42 MSE: 0.004792962688952684 Epoch 43 MSE: 0.005101779010146856 Epoch 44 MSE: 0.0054450989700853825 Epoch 45 MSE: 0.005403035320341587 Epoch 46 MSE: 0.0049084387719631195 Epoch 47 MSE: 0.004292689263820648 Epoch 48 MSE: 0.004057419020682573 Epoch 49 MSE: 0.004382362123578787 Epoch 50 MSE: 0.004800620023161173 Epoch 51 MSE: 0.0047332593239843845 Epoch 52 MSE: 0.00429498078301549 Epoch 53 MSE: 0.004010020289570093 Epoch 54 MSE: 0.00405229302123189 Epoch 55 MSE: 0.0042020706459879875 Epoch 56 MSE: 0.00423399405553937 Epoch 57 MSE: 
0.004127885214984417 Epoch 58 MSE: 0.004010280594229698 Epoch 59 MSE: 0.0039786044508218765 Epoch 60 MSE: 0.003988081123679876 Epoch 61 MSE: 0.003933039028197527 Epoch 62 MSE: 0.003823513863608241 Epoch 63 MSE: 0.003776049939915538 Epoch 64 MSE: 0.003824274055659771 Epoch 65 MSE: 0.003863774472847581 Epoch 66 MSE: 0.003804313950240612 Epoch 67 MSE: 0.0036872613709419966 Epoch 68 MSE: 0.0036205435171723366 Epoch 69 MSE: 0.003640422597527504 Epoch 70 MSE: 0.0036781001836061478 Epoch 71 MSE: 0.003658491186797619 Epoch 72 MSE: 0.0035898007918149233 Epoch 73 MSE: 0.0035334117710590363 Epoch 74 MSE: 0.003518172074109316 Epoch 75 MSE: 0.003517123172059655 Epoch 76 MSE: 0.0034970585256814957 Epoch 77 MSE: 0.0034619851503521204 Epoch 78 MSE: 0.0034362394362688065 Epoch 79 MSE: 0.0034217468928545713 Epoch 80 MSE: 0.0033987327478826046 Epoch 81 MSE: 0.0033639115281403065 Epoch 82 MSE: 0.003336795140057802 Epoch 83 MSE: 0.0033262858632951975 Epoch 84 MSE: 0.0033164448104798794 Epoch 85 MSE: 0.003290680004283786 Epoch 86 MSE: 0.0032549300231039524 Epoch 87 MSE: 0.0032281929161399603 Epoch 88 MSE: 0.0032156326342374086 Epoch 89 MSE: 0.003202912863343954 Epoch 90 MSE: 0.0031789366621524096 Epoch 91 MSE: 0.003150469856336713 Epoch 92 MSE: 0.0031281677074730396 Epoch 93 MSE: 0.003110940335318446 Epoch 94 MSE: 0.0030915706884115934 Epoch 95 MSE: 0.003068988909944892 Epoch 96 MSE: 0.0030475202947854996 Epoch 97 MSE: 0.0030280023347586393 Epoch 98 MSE: 0.003006906481459737 Epoch 99 MSE: 0.0029837179463356733 Epoch 100 MSE: 0.002962325233966112 Epoch 101 MSE: 0.00294388341717422 Epoch 102 MSE: 0.002924472326412797 Epoch 103 MSE: 0.002901745028793812 Epoch 104 MSE: 0.0028787213377654552 Epoch 105 MSE: 0.00285856775008142 Epoch 106 MSE: 0.0028396756388247013 Epoch 107 MSE: 0.002818876411765814 Epoch 108 MSE: 0.002796490676701069 Epoch 109 MSE: 0.002774963853880763 Epoch 110 MSE: 0.0027546226046979427 Epoch 111 MSE: 0.0027338836807757616 Epoch 112 MSE: 0.002712367568165064 Epoch 113 MSE: 
0.002691052621230483 Epoch 114 MSE: 0.0026700596790760756 Epoch 115 MSE: 0.002648686757311225 Epoch 116 MSE: 0.0026271298993378878 Epoch 117 MSE: 0.0026060682721436024 Epoch 118 MSE: 0.0025851137470453978 Epoch 119 MSE: 0.0025634451303631067 Epoch 120 MSE: 0.0025413997936993837 Epoch 121 MSE: 0.002519859466701746 Epoch 122 MSE: 0.0024986269418150187 Epoch 123 MSE: 0.0024769043084234 Epoch 124 MSE: 0.00245482474565506 Epoch 125 MSE: 0.0024330804590135813 Epoch 126 MSE: 0.0024115964770317078 Epoch 127 MSE: 0.00238981656730175 Epoch 128 MSE: 0.0023677637800574303 Epoch 129 MSE: 0.002345810877159238 Epoch 130 MSE: 0.0023239431902766228 Epoch 131 MSE: 0.002301996573805809 Epoch 132 MSE: 0.0022800967562943697 Epoch 133 MSE: 0.0022583564277738333 Epoch 134 MSE: 0.0022366882767528296 Epoch 135 MSE: 0.00221512233838439 Epoch 136 MSE: 0.002193776424974203 Epoch 137 MSE: 0.0021725890692323446 Epoch 138 MSE: 0.0021515486296266317 Epoch 139 MSE: 0.0021308877039700747 Epoch 140 MSE: 0.0021107173524796963 Epoch 141 MSE: 0.0020909467712044716 Epoch 142 MSE: 0.0020716742146760225 Epoch 143 MSE: 0.0020530743058770895 Epoch 144 MSE: 0.0020350839477032423 Epoch 145 MSE: 0.0020177073311060667 Epoch 146 MSE: 0.0020011684391647577 Epoch 147 MSE: 0.0019855087157338858 Epoch 148 MSE: 0.001970660872757435 Epoch 149 MSE: 0.001956736436113715 Epoch 150 MSE: 0.001943748677149415 Epoch 151 MSE: 0.0019316324032843113 Epoch 152 MSE: 0.001920461654663086 Epoch 153 MSE: 0.001910197315737605 Epoch 154 MSE: 0.0019007594091817737 Epoch 155 MSE: 0.0018921061418950558 Epoch 156 MSE: 0.0018841071287170053 Epoch 157 MSE: 0.0018767131259664893 Epoch 158 MSE: 0.001869786879979074 Epoch 159 MSE: 0.0018631998682394624 Epoch 160 MSE: 0.001856886432506144 Epoch 161 MSE: 0.0018506691558286548 Epoch 162 MSE: 0.00184451334644109 Epoch 163 MSE: 0.0018383247079327703 Epoch 164 MSE: 0.0018320315284654498 Epoch 165 MSE: 0.001825616112910211 Epoch 166 MSE: 0.0018190358532592654 Epoch 167 MSE: 0.0018123004119843245 
Epoch 168 MSE: 0.0018053972162306309 Epoch 169 MSE: 0.001798370503820479 Epoch 170 MSE: 0.0017912328476086259 Epoch 171 MSE: 0.0017840288346633315 Epoch 172 MSE: 0.0017768022371456027 Epoch 173 MSE: 0.0017695720307528973 Epoch 174 MSE: 0.0017623926978558302 Epoch 175 MSE: 0.0017552823992446065 Epoch 176 MSE: 0.001748264068737626 Epoch 177 MSE: 0.0017413527239114046 Epoch 178 MSE: 0.0017345469677820802 Epoch 179 MSE: 0.0017278512241318822 Epoch 180 MSE: 0.0017212586244568229 Epoch 181 MSE: 0.001714757177978754 Epoch 182 MSE: 0.0017083346610888839 Epoch 183 MSE: 0.0017019727965816855 Epoch 184 MSE: 0.001695660874247551 Epoch 185 MSE: 0.0016893836436793208 Epoch 186 MSE: 0.0016831276006996632 Epoch 187 MSE: 0.0016768834320828319 Epoch 188 MSE: 0.0016706380993127823 Epoch 189 MSE: 0.001664385781623423 Epoch 190 MSE: 0.0016581222880631685 Epoch 191 MSE: 0.0016518420306965709 Epoch 192 MSE: 0.0016455440782010555 Epoch 193 MSE: 0.0016392271500080824 Epoch 194 MSE: 0.0016328906640410423 Epoch 195 MSE: 0.0016265381127595901 Epoch 196 MSE: 0.0016201696125790477 Epoch 197 MSE: 0.0016137883067131042 Epoch 198 MSE: 0.0016073979204520583 Epoch 199 MSE: 0.0016009998507797718 Epoch 200 MSE: 0.0015945985214784741 Epoch 201 MSE: 0.001588196842931211 Epoch 202 MSE: 0.0015817983075976372 Epoch 203 MSE: 0.0015754059422761202 Epoch 204 MSE: 0.001569023123010993 Epoch 205 MSE: 0.0015626528766006231 Epoch 206 MSE: 0.0015562985790893435 Epoch 207 MSE: 0.0015499631408602 Epoch 208 MSE: 0.0015436491230502725 Epoch 209 MSE: 0.0015373599017038941 Epoch 210 MSE: 0.0015310978051275015 Epoch 211 MSE: 0.001524865860119462 Epoch 212 MSE: 0.001518666627816856 Epoch 213 MSE: 0.0015125020872801542 Epoch 214 MSE: 0.0015063751488924026 Epoch 215 MSE: 0.0015002868603914976 Epoch 216 MSE: 0.0014942394336685538 Epoch 217 MSE: 0.0014882339164614677 Epoch 218 MSE: 0.0014822717057541013 Epoch 219 MSE: 0.0014763528015464544 Epoch 220 MSE: 0.0014704770874232054 Epoch 221 MSE: 0.0014646442141383886 Epoch 222 
MSE: 0.001458853716030717 Epoch 223 MSE: 0.0014531039632856846 Epoch 224 MSE: 0.0014473930932581425 Epoch 225 MSE: 0.0014417197089642286 Epoch 226 MSE: 0.0014360809000208974 Epoch 227 MSE: 0.0014304752694442868 Epoch 228 MSE: 0.0014248999068513513 Epoch 229 MSE: 0.0014193531824275851 Epoch 230 MSE: 0.00141383265145123 Epoch 231 MSE: 0.0014083373825997114 Epoch 232 MSE: 0.0014028659788891673 Epoch 233 MSE: 0.0013974178582429886 Epoch 234 MSE: 0.0013919930206611753 Epoch 235 MSE: 0.0013865914661437273 Epoch 236 MSE: 0.001381214358843863 Epoch 237 MSE: 0.001375863328576088 Epoch 238 MSE: 0.0013705401215702295 Epoch 239 MSE: 0.0013652476482093334 Epoch 240 MSE: 0.0013599882367998362 Epoch 241 MSE: 0.0013547652633860707 Epoch 242 MSE: 0.0013495823368430138 Epoch 243 MSE: 0.0013444427167996764 Epoch 244 MSE: 0.0013393503613770008 Epoch 245 MSE: 0.0013343093451112509 Epoch 246 MSE: 0.0013293232768774033 Epoch 247 MSE: 0.0013243958819657564 Epoch 248 MSE: 0.0013195313513278961 Epoch 249 MSE: 0.0013147329445928335 Epoch 250 MSE: 0.0013100046198815107 Epoch 251 MSE: 0.0013053498696535826 Epoch 252 MSE: 0.0013007718371227384 Epoch 253 MSE: 0.001296274014748633 Epoch 254 MSE: 0.0012918595457449555 Epoch 255 MSE: 0.001287531922571361 Epoch 256 MSE: 0.0012832944048568606 Epoch 257 MSE: 0.0012791494373232126 Epoch 258 MSE: 0.0012751005124300718 Epoch 259 MSE: 0.001271150540560484 Epoch 260 MSE: 0.0012673025485128164 Epoch 261 MSE: 0.001263559446670115 Epoch 262 MSE: 0.001259923679754138 Epoch 263 MSE: 0.001256398158147931 Epoch 264 MSE: 0.0012529853265732527 Epoch 265 MSE: 0.0012496879789978266 Epoch 266 MSE: 0.0012465112376958132 Epoch 267 MSE: 0.0012434730306267738 Epoch 268 MSE: 0.0012406742898747325 Epoch 269 MSE: 0.0012388165341690183 Epoch 270 MSE: 0.0012429836206138134 Epoch 271 MSE: 0.0012875766260549426 Epoch 272 MSE: 0.001492623589001596 Epoch 273 MSE: 0.00201231031678617 Epoch 274 MSE: 0.0016974164173007011 Epoch 275 MSE: 0.0012310986639931798 Epoch 276 MSE: 
0.001519946032203734 Epoch 277 MSE: 0.0012483973987400532 Epoch 278 MSE: 0.0014362889342010021 Epoch 279 MSE: 0.0012453354429453611 Epoch 280 MSE: 0.0013879925245419145 Epoch 281 MSE: 0.0012415138771757483 Epoch 282 MSE: 0.0013531947042793036 Epoch 283 MSE: 0.0012601300841197371 Epoch 284 MSE: 0.0013007718371227384 Epoch 285 MSE: 0.0012898016721010208 Epoch 286 MSE: 0.001256495830602944 Epoch 287 MSE: 0.0013036596355959773 Epoch 288 MSE: 0.0012414618395268917 Epoch 289 MSE: 0.0012859577545896173 Epoch 290 MSE: 0.0012539169983938336 Epoch 291 MSE: 0.0012540979078039527 Epoch 292 MSE: 0.001267384272068739 Epoch 293 MSE: 0.0012348323361948133 Epoch 294 MSE: 0.0012621035566553473 Epoch 295 MSE: 0.0012356534134596586 Epoch 296 MSE: 0.0012432103976607323 Epoch 297 MSE: 0.0012424476444721222 Epoch 298 MSE: 0.001227591885253787 Epoch 299 MSE: 0.0012415907112881541 Epoch 300 MSE: 0.0012228378327563405 Epoch 301 MSE: 0.0012320202076807618 Epoch 302 MSE: 0.0012245419202372432 Epoch 303 MSE: 0.001221132930368185 Epoch 304 MSE: 0.001225392334163189 Epoch 305 MSE: 0.0012146109947934747 Epoch 306 MSE: 0.0012220368953421712 Epoch 307 MSE: 0.0012130143586546183 Epoch 308 MSE: 0.001216199016198516 Epoch 309 MSE: 0.0012134237913414836 Epoch 310 MSE: 0.0012109134113416076 Epoch 311 MSE: 0.0012131142430007458 Epoch 312 MSE: 0.0012078388826921582 Epoch 313 MSE: 0.0012114611454308033 Epoch 314 MSE: 0.0012066643685102463 Epoch 315 MSE: 0.0012091086246073246 Epoch 316 MSE: 0.0012064219918102026 Epoch 317 MSE: 0.0012068765936419368 Epoch 318 MSE: 0.0012063715839758515 Epoch 319 MSE: 0.001205186010338366 Epoch 320 MSE: 0.0012061332818120718 Epoch 321 MSE: 0.0012041557347401977 Epoch 322 MSE: 0.0012056321138516068 Epoch 323 MSE: 0.0012036605039611459 Epoch 324 MSE: 0.001204974832944572 Epoch 325 MSE: 0.001203469350002706 Epoch 326 MSE: 0.0012043254682794213 Epoch 327 MSE: 0.0012033901875838637 Epoch 328 MSE: 0.0012037695851176977 Epoch 329 MSE: 0.0012033282546326518 Epoch 330 MSE: 
0.0012033232487738132 Epoch 331 MSE: 0.0012032492086291313 Epoch 332 MSE: 0.001202977728098631 Epoch 333 MSE: 0.0012031331425532699 Epoch 334 MSE: 0.0012027207994833589 Epoch 335 MSE: 0.0012029799399897456 Epoch 336 MSE: 0.001202528947032988 Epoch 337 MSE: 0.0012028022902086377 Epoch 338 MSE: 0.0012023745803162456 Epoch 339 MSE: 0.0012026189360767603 Epoch 340 MSE: 0.0012022380251437426 Epoch 341 MSE: 0.001202437444590032 Epoch 342 MSE: 0.001202112645842135 Epoch 343 MSE: 0.0012022604933008552 Epoch 344 MSE: 0.0012019947171211243 Epoch 345 MSE: 0.001202090410515666 Epoch 346 MSE: 0.0012018806301057339 Epoch 347 MSE: 0.001201931620016694 Epoch 348 MSE: 0.001201768172904849 Epoch 349 MSE: 0.0012017858680337667 Epoch 350 MSE: 0.0012016582768410444 Epoch 351 MSE: 0.001201652456074953 Epoch 352 MSE: 0.00120155222248286 Epoch 353 MSE: 0.001201529405079782 Epoch 354 MSE: 0.0012014509411528707 Epoch 355 MSE: 0.0012014165986329317 Epoch 356 MSE: 0.0012013546656817198 Epoch 357 MSE: 0.0012013136874884367 Epoch 358 MSE: 0.001201263046823442 Epoch 359 MSE: 0.0012012196239084005 Epoch 360 MSE: 0.0012011764338240027 Epoch 361 MSE: 0.0012011328944936395 Epoch 362 MSE: 0.0012010951759293675 Epoch 363 MSE: 0.0012010523350909352 Epoch 364 MSE: 0.0012010184582322836 Epoch 365 MSE: 0.0012009773636236787 Epoch 366 MSE: 0.0012009465135633945 Epoch 367 MSE: 0.0012009072815999389 Epoch 368 MSE: 0.001200877595692873 Epoch 369 MSE: 0.0012008408084511757 Epoch 370 MSE: 0.001200812985189259 Epoch 371 MSE: 0.0012007781770080328 Epoch 372 MSE: 0.0012007508194074035 Epoch 373 MSE: 0.0012007184559479356 Epoch 374 MSE: 0.0012006916804239154 Epoch 375 MSE: 0.001200660946778953 Epoch 376 MSE: 0.0012006351025775075 Epoch 377 MSE: 0.0012006063479930162 Epoch 378 MSE: 0.0012005807366222143 Epoch 379 MSE: 0.0012005536118522286 Epoch 380 MSE: 0.0012005286989733577 Epoch 381 MSE: 0.0012005026219412684 Epoch 382 MSE: 0.0012004779418930411 Epoch 383 MSE: 0.0012004532618448138 Epoch 384 MSE: 
0.0012004293967038393 Epoch 385 MSE: 0.0012004057643935084 Epoch 386 MSE: 0.0012003821320831776 Epoch 387 MSE: 0.001200359663926065 Epoch 388 MSE: 0.0012003362644463778 Epoch 389 MSE: 0.0012003148440271616 Epoch 390 MSE: 0.0012002922594547272 Epoch 391 MSE: 0.0012002714211121202 Epoch 392 MSE: 0.0012002494186162949 Epoch 393 MSE: 0.0012002295115962625 Epoch 394 MSE: 0.0012002078583464026 Epoch 395 MSE: 0.0012001884169876575 Epoch 396 MSE: 0.0012001674622297287 Epoch 397 MSE: 0.0012001482537016273 Epoch 398 MSE: 0.0012001279974356294 Epoch 399 MSE: 0.0012001097202301025 Epoch 400 MSE: 0.0012000901624560356 Epoch 401 MSE: 0.0012000715360045433 Epoch 402 MSE: 0.0012000527931377292 Epoch 403 MSE: 0.0012000349815934896 Epoch 404 MSE: 0.001200016587972641 Epoch 405 MSE: 0.0011999986600130796 Epoch 406 MSE: 0.0011999811977148056 Epoch 407 MSE: 0.0011999636190012097 Epoch 408 MSE: 0.0011999463895335793 Epoch 409 MSE: 0.0011999292764812708 Epoch 410 MSE: 0.001199912279844284 Epoch 411 MSE: 0.0011998956324532628 Epoch 412 MSE: 0.0011998791014775634 Epoch 413 MSE: 0.0011998626869171858 Epoch 414 MSE: 0.0011998465051874518 Epoch 415 MSE: 0.0011998304398730397 Epoch 416 MSE: 0.0011998146073892713 Epoch 417 MSE: 0.001199798658490181 Epoch 418 MSE: 0.0011997830588370562 Epoch 419 MSE: 0.0011997671099379659 Epoch 420 MSE: 0.001199752208776772 Epoch 421 MSE: 0.001199736725538969 Epoch 422 MSE: 0.0011997215915471315 Epoch 423 MSE: 0.0011997065739706159 Epoch 424 MSE: 0.001199691672809422 Epoch 425 MSE: 0.0011996767716482282 Epoch 426 MSE: 0.0011996619869023561 Epoch 427 MSE: 0.0011996474349871278 Epoch 428 MSE: 0.0011996329994872212 Epoch 429 MSE: 0.0011996185639873147 Epoch 430 MSE: 0.0011996041284874082 Epoch 431 MSE: 0.0011995899258181453 Epoch 432 MSE: 0.0011995757231488824 Epoch 433 MSE: 0.0011995616368949413 Epoch 434 MSE: 0.001199547667056322 Epoch 435 MSE: 0.0011995336972177029 Epoch 436 MSE: 0.0011995199602097273 Epoch 437 MSE: 0.0011995062232017517 Epoch 438 MSE: 
0.0011994928354397416 Epoch 439 MSE: 0.0011994787491858006 Epoch 440 MSE: 0.0011994653614237905 Epoch 441 MSE: 0.0011994517408311367 Epoch 442 MSE: 0.0011994385858997703 Epoch 443 MSE: 0.0011994248488917947 Epoch 444 MSE: 0.00119941181037575 Epoch 445 MSE: 0.00119939842261374 Epoch 446 MSE: 0.0011993851512670517 Epoch 447 MSE: 0.0011993723455816507 Epoch 448 MSE: 0.0011993593070656061 Epoch 449 MSE: 0.0011993460357189178 Epoch 450 MSE: 0.0011993328807875514 Epoch 451 MSE: 0.001199320307932794 Epoch 452 MSE: 0.0011993072694167495 Epoch 453 MSE: 0.0011992944637313485 Epoch 454 MSE: 0.0011992815416306257 Epoch 455 MSE: 0.0011992689687758684 Epoch 456 MSE: 0.0011992561630904675 Epoch 457 MSE: 0.001199243706651032 Epoch 458 MSE: 0.0011992307845503092 Epoch 459 MSE: 0.0011992183281108737 Epoch 460 MSE: 0.0011992056388407946 Epoch 461 MSE: 0.001199193182401359 Epoch 462 MSE: 0.00119918049313128 Epoch 463 MSE: 0.0011991680366918445 Epoch 464 MSE: 0.0011991556966677308 Epoch 465 MSE: 0.0011991433566436172 Epoch 466 MSE: 0.0011991310166195035 Epoch 467 MSE: 0.001199118560180068 Epoch 468 MSE: 0.0011991062201559544 Epoch 469 MSE: 0.0011990941129624844 Epoch 470 MSE: 0.001199081540107727 Epoch 471 MSE: 0.0011990695493295789 Epoch 472 MSE: 0.0011990572093054652 Epoch 473 MSE: 0.0011990451021119952 Epoch 474 MSE: 0.0011990328785032034 Epoch 475 MSE: 0.0011990207713097334 Epoch 476 MSE: 0.0011990084312856197 Epoch 477 MSE: 0.0011989963240921497 Epoch 478 MSE: 0.0011989843333140016 Epoch 479 MSE: 0.0011989721097052097 Epoch 480 MSE: 0.0011989601189270616 Epoch 481 MSE: 0.0011989480117335916 Epoch 482 MSE: 0.0011989361373707652 Epoch 483 MSE: 0.001198924146592617 Epoch 484 MSE: 0.0011989121558144689 Epoch 485 MSE: 0.0011989000486209989 Epoch 486 MSE: 0.0011988877085968852 Epoch 487 MSE: 0.0011988759506493807 Epoch 488 MSE: 0.0011988637270405889 Epoch 489 MSE: 0.0011988519690930843 Epoch 490 MSE: 0.0011988397454842925 Epoch 491 MSE: 0.0011988281039521098 Epoch 492 MSE: 
0.0011988162295892835 Epoch 493 MSE: 0.0011988041223958135 Epoch 494 MSE: 0.0011987922480329871 Epoch 495 MSE: 0.0011987803736701608 Epoch 496 MSE: 0.0011987684993073344 Epoch 497 MSE: 0.0011987563921138644 Epoch 498 MSE: 0.001198744517751038 Epoch 499 MSE: 0.0011987326433882117 Training time: 11.960481882095337
# Undo the scaling so predictions and targets are in dollars again.
predict = pd.DataFrame(scaler.inverse_transform(y_train_pred.cpu().detach().numpy()))
original = pd.DataFrame(scaler.inverse_transform(y_train_lstm.cpu().detach().numpy()))

sns.set_style("darkgrid")
fig = plt.figure()
fig.subplots_adjust(hspace=0.2, wspace=0.2)

# Left panel: ground truth vs. training-set predictions.
plt.subplot(1, 2, 1)
ax = sns.lineplot(x=original.index, y=original[0], label="Data", color='royalblue')
ax = sns.lineplot(x=predict.index, y=predict[0], label="Training Prediction (LSTM)", color='tomato')
ax.set_title('Stock price', size=14, fontweight='bold')
ax.set_xlabel("Days", size=14)
ax.set_ylabel("Cost (USD)", size=14)
ax.set_xticklabels('', size=10)

# Right panel: the loss curve recorded during training.
plt.subplot(1, 2, 2)
ax = sns.lineplot(data=hist, color='royalblue')
ax.set_title("Training Loss", size=14, fontweight='bold')
ax.set_xlabel("Epoch", size=14)
ax.set_ylabel("Loss", size=14)

fig.set_figheight(6)
fig.set_figwidth(16)
# Move everything back to the CPU for numpy post-processing.
device = torch.device("cpu")
model.to(device)
x_train = x_train.to(device)
y_train_pred = y_train_pred.to(device)
y_train_lstm = y_train_lstm.to(device)

# make predictions on the held-out test windows
y_test_pred = model(x_test).cpu()

# invert predictions and targets back to the original price scale
y_train_pred = scaler.inverse_transform(y_train_pred.detach().numpy())
y_train = scaler.inverse_transform(y_train_lstm.detach().numpy())
y_test_pred = scaler.inverse_transform(y_test_pred.detach().numpy())
y_test = scaler.inverse_transform(y_test_lstm.detach().numpy())

# calculate root mean squared error for both splits
trainScore = math.sqrt(mean_squared_error(y_train[:, 0], y_train_pred[:, 0]))
print('Train Score: %.2f RMSE' % (trainScore))
testScore = math.sqrt(mean_squared_error(y_test[:, 0], y_test_pred[:, 0]))
print('Test Score: %.2f RMSE' % (testScore))

# record scores and training time for the LSTM summary
lstm.extend([trainScore, testScore, training_time])
Train Score: 0.51 RMSE Test Score: 0.93 RMSE
# shift train predictions for plotting: align them with the original
# time axis (the first `lookback` days have no prediction)
trainPredictPlot = np.empty_like(price)
trainPredictPlot[:, :] = np.nan
trainPredictPlot[lookback:len(y_train_pred) + lookback, :] = y_train_pred

# shift test predictions for plotting: they follow the training span
testPredictPlot = np.empty_like(price)
testPredictPlot[:, :] = np.nan
testPredictPlot[len(y_train_pred) + lookback - 1:len(price) - 1, :] = y_test_pred

# column 0: train predictions, column 1: test predictions, column 2: actual
original = scaler.inverse_transform(price['Adj Close'].values.reshape(-1, 1))
predictions = np.append(trainPredictPlot, testPredictPlot, axis=1)
predictions = np.append(predictions, original, axis=1)
result = pd.DataFrame(predictions)

fig = go.Figure()
# BUG FIX: two of the traces were double-wrapped as
# go.Scatter(go.Scatter(...)); each trace needs a single go.Scatter call.
fig.add_trace(go.Scatter(x=df["Date"], y=result[0],
                         mode='lines',
                         name='Train prediction'))
fig.add_trace(go.Scatter(x=df["Date"], y=result[1],
                         mode='lines',
                         name='Test prediction'))
fig.add_trace(go.Scatter(x=df["Date"], y=result[2],
                         mode='lines',
                         name='Actual Value'))

# Range-selector buttons plus a draggable range slider on the x axis.
fig.update_xaxes(
    rangeslider_visible=True,
    rangeselector=dict(
        buttons=list([
            dict(count=1, label="1m", step="month", stepmode="backward"),
            dict(count=6, label="6m", step="month", stepmode="backward"),
            dict(count=1, label="YTD", step="year", stepmode="todate"),
            dict(count=1, label="1y", step="year", stepmode="backward"),
            dict(count=1, label="2y", step="year", stepmode="backward"),
            dict(step="all")
        ]),
        font=dict(family='Rockwell', color='black'),
    )
)
fig.show()
class GRU(nn.Module):
    """Many-to-one GRU regressor with a linear head.

    Mirrors the LSTM model: input shaped (batch, seq_len, input_dim),
    output shaped (batch, output_dim) from the last time step.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(GRU, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # batch_first=True -> inputs/outputs are (batch, seq, feature).
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Zero initial hidden state on the input's device. The original
        # relied on a global `device` and a pointless
        # requires_grad_()/detach() pair; both removed.
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim, device=x.device)
        out, _ = self.gru(x, h0)
        # Only the last time step feeds the regression head.
        return self.fc(out[:, -1, :])
# Build and train the GRU model. input_dim, hidden_dim, output_dim,
# num_layers, num_epochs, device, x_train and y_train_gru are defined in
# earlier cells (same hyperparameters as the LSTM run, for a fair comparison).
model = GRU(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_layers=num_layers)
criterion = torch.nn.MSELoss(reduction='mean')
optimiser = torch.optim.Adam(model.parameters(), lr=0.01)
# Move model, loss and training tensors to the active device.
model = model.to(device)
criterion.to(device)
x_train = x_train.to(device)
y_train_gru = y_train_gru.to(device)
# Per-epoch loss history, plotted later as the training-loss curve.
hist = np.zeros(num_epochs)
start_time = time.time()
gru = []  # will collect [train RMSE, test RMSE, training time]
# Full-batch training: one forward/backward pass over all of x_train per epoch.
for t in range(num_epochs):
    y_train_pred = model(x_train)
    loss = criterion(y_train_pred, y_train_gru)
    print("Epoch ", t, "MSE: ", loss.item())
    hist[t] = loss.item()
    # Standard PyTorch update: clear gradients, backprop, step.
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
training_time = time.time()-start_time
print("Training time: {}".format(training_time))
Epoch 0 MSE: 0.09333141893148422 Epoch 1 MSE: 0.14284326136112213 Epoch 2 MSE: 0.04148746281862259 Epoch 3 MSE: 0.08071506023406982 Epoch 4 MSE: 0.031027598306536674 Epoch 5 MSE: 0.011289249174296856 Epoch 6 MSE: 0.019751261919736862 Epoch 7 MSE: 0.022235073149204254 Epoch 8 MSE: 0.010761172510683537 Epoch 9 MSE: 0.005481211934238672 Epoch 10 MSE: 0.007945975288748741 Epoch 11 MSE: 0.01044232677668333 Epoch 12 MSE: 0.011669597588479519 Epoch 13 MSE: 0.01093620341271162 Epoch 14 MSE: 0.00786434393376112 Epoch 15 MSE: 0.0041869031265378 Epoch 16 MSE: 0.0030486483592540026 Epoch 17 MSE: 0.005431350786238909 Epoch 18 MSE: 0.007107957731932402 Epoch 19 MSE: 0.005823053885251284 Epoch 20 MSE: 0.004497723188251257 Epoch 21 MSE: 0.0043601300567388535 Epoch 22 MSE: 0.004275926388800144 Epoch 23 MSE: 0.003639059141278267 Epoch 24 MSE: 0.0031692786142230034 Epoch 25 MSE: 0.0037069960962980986 Epoch 26 MSE: 0.004504668992012739 Epoch 27 MSE: 0.004151158966124058 Epoch 28 MSE: 0.0032008399721235037 Epoch 29 MSE: 0.0028251002077013254 Epoch 30 MSE: 0.0029367287643253803 Epoch 31 MSE: 0.0029540611431002617 Epoch 32 MSE: 0.0027779985684901476 Epoch 33 MSE: 0.002738407114520669 Epoch 34 MSE: 0.0029480746015906334 Epoch 35 MSE: 0.0030101609881967306 Epoch 36 MSE: 0.0026732529513537884 Epoch 37 MSE: 0.002299028914421797 Epoch 38 MSE: 0.002244044793769717 Epoch 39 MSE: 0.002394381444901228 Epoch 40 MSE: 0.0024728148709982634 Epoch 41 MSE: 0.002410072134807706 Epoch 42 MSE: 0.002331031020730734 Epoch 43 MSE: 0.002305177040398121 Epoch 44 MSE: 0.0022459591273218393 Epoch 45 MSE: 0.0021093646064400673 Epoch 46 MSE: 0.0020122758578509092 Epoch 47 MSE: 0.002040587831288576 Epoch 48 MSE: 0.002114779083058238 Epoch 49 MSE: 0.002115062205120921 Epoch 50 MSE: 0.0020325342193245888 Epoch 51 MSE: 0.0019503985531628132 Epoch 52 MSE: 0.0019168432336300611 Epoch 53 MSE: 0.0018990904791280627 Epoch 54 MSE: 0.0018699258798733354 Epoch 55 MSE: 0.0018547893268987536 Epoch 56 MSE: 0.0018639949848875403 
Epoch 57 MSE: 0.001859020907431841 Epoch 58 MSE: 0.0018133263802155852 Epoch 59 MSE: 0.0017564125591889024 Epoch 60 MSE: 0.0017312460113316774 Epoch 61 MSE: 0.0017338049365207553 Epoch 62 MSE: 0.0017293855780735612 Epoch 63 MSE: 0.001710184384137392 Epoch 64 MSE: 0.0016924866940826178 Epoch 65 MSE: 0.0016767147462815046 Epoch 66 MSE: 0.0016523788217455149 Epoch 67 MSE: 0.0016270113410428166 Epoch 68 MSE: 0.0016172248870134354 Epoch 69 MSE: 0.0016179620288312435 Epoch 70 MSE: 0.0016088931588456035 Epoch 71 MSE: 0.0015877644764259458 Epoch 72 MSE: 0.0015694905305281281 Epoch 73 MSE: 0.001558594754897058 Epoch 74 MSE: 0.0015478333225473762 Epoch 75 MSE: 0.0015364120481535792 Epoch 76 MSE: 0.001528959721326828 Epoch 77 MSE: 0.0015219537308439612 Epoch 78 MSE: 0.0015088655054569244 Epoch 79 MSE: 0.0014938311651349068 Epoch 80 MSE: 0.0014847107231616974 Epoch 81 MSE: 0.0014796123141422868 Epoch 82 MSE: 0.001472295611165464 Epoch 83 MSE: 0.0014628851786255836 Epoch 84 MSE: 0.0014543880242854357 Epoch 85 MSE: 0.0014460552483797073 Epoch 86 MSE: 0.0014371885918080807 Epoch 87 MSE: 0.001430271309800446 Epoch 88 MSE: 0.0014254908310249448 Epoch 89 MSE: 0.0014194028917700052 Epoch 90 MSE: 0.0014112520730122924 Epoch 91 MSE: 0.0014039549278095365 Epoch 92 MSE: 0.0013984640827402472 Epoch 93 MSE: 0.0013930745190009475 Epoch 94 MSE: 0.0013874280266463757 Epoch 95 MSE: 0.0013821744360029697 Epoch 96 MSE: 0.0013766869669780135 Epoch 97 MSE: 0.0013706929748877883 Epoch 98 MSE: 0.0013654601061716676 Epoch 99 MSE: 0.0013612997718155384 Epoch 100 MSE: 0.0013568734284490347 Epoch 101 MSE: 0.0013518015621230006 Epoch 102 MSE: 0.0013470025733113289 Epoch 103 MSE: 0.0013426828663796186 Epoch 104 MSE: 0.0013384444173425436 Epoch 105 MSE: 0.0013343938626348972 Epoch 106 MSE: 0.0013305442407727242 Epoch 107 MSE: 0.0013264798326417804 Epoch 108 MSE: 0.001322334399446845 Epoch 109 MSE: 0.0013186425203457475 Epoch 110 MSE: 0.001315259956754744 Epoch 111 MSE: 0.0013117281487211585 Epoch 112 MSE: 
0.0013081388315185905 Epoch 113 MSE: 0.0013047363609075546 Epoch 114 MSE: 0.0013014484429731965 Epoch 115 MSE: 0.001298255636356771 Epoch 116 MSE: 0.0012952324468642473 Epoch 117 MSE: 0.001292238594032824 Epoch 118 MSE: 0.0012891854858025908 Epoch 119 MSE: 0.0012862655567005277 Epoch 120 MSE: 0.0012835602974519134 Epoch 121 MSE: 0.0012809012550860643 Epoch 122 MSE: 0.001278234412893653 Epoch 123 MSE: 0.0012756555806845427 Epoch 124 MSE: 0.0012731729075312614 Epoch 125 MSE: 0.0012707647401839495 Epoch 126 MSE: 0.0012684576213359833 Epoch 127 MSE: 0.0012662160443142056 Epoch 128 MSE: 0.0012639909982681274 Epoch 129 MSE: 0.0012618412729352713 Epoch 130 MSE: 0.0012598125031217933 Epoch 131 MSE: 0.0012578460155054927 Epoch 132 MSE: 0.001255908515304327 Epoch 133 MSE: 0.0012540369061753154 Epoch 134 MSE: 0.0012522421311587095 Epoch 135 MSE: 0.0012505091726779938 Epoch 136 MSE: 0.0012488430365920067 Epoch 137 MSE: 0.001247233827598393 Epoch 138 MSE: 0.0012456653639674187 Epoch 139 MSE: 0.001244157087057829 Epoch 140 MSE: 0.001242722850292921 Epoch 141 MSE: 0.0012413384392857552 Epoch 142 MSE: 0.0012399920960888267 Epoch 143 MSE: 0.0012386994203552604 Epoch 144 MSE: 0.0012374641373753548 Epoch 145 MSE: 0.0012362755369395018 Epoch 146 MSE: 0.0012351330369710922 Epoch 147 MSE: 0.0012340336106717587 Epoch 148 MSE: 0.0012329750461503863 Epoch 149 MSE: 0.0012319623492658138 Epoch 150 MSE: 0.0012309955200180411 Epoch 151 MSE: 0.0012300637317821383 Epoch 152 MSE: 0.001229165936820209 Epoch 153 MSE: 0.0012283088872209191 Epoch 154 MSE: 0.0012274902546778321 Epoch 155 MSE: 0.001226703287102282 Epoch 156 MSE: 0.001225948566570878 Epoch 157 MSE: 0.0012252257438376546 Epoch 158 MSE: 0.0012245336547493935 Epoch 159 MSE: 0.0012238724157214165 Epoch 160 MSE: 0.0012232392327859998 Epoch 161 MSE: 0.0012226310791447759 Epoch 162 MSE: 0.0012220497010275722 Epoch 163 MSE: 0.0012214949820190668 Epoch 164 MSE: 0.0012209640117362142 Epoch 165 MSE: 0.0012204546947032213 Epoch 166 MSE: 
0.001219967845827341 Epoch 167 MSE: 0.001219503115862608 Epoch 168 MSE: 0.001219058409333229 Epoch 169 MSE: 0.001218633260577917 Epoch 170 MSE: 0.0012182266218587756 Epoch 171 MSE: 0.0012178380275145173 Epoch 172 MSE: 0.0012174672447144985 Epoch 173 MSE: 0.0012171129928901792 Epoch 174 MSE: 0.0012167737586423755 Epoch 175 MSE: 0.0012164500076323748 Epoch 176 MSE: 0.001216141157783568 Epoch 177 MSE: 0.0012158462777733803 Epoch 178 MSE: 0.0012155642034485936 Epoch 179 MSE: 0.0012152951676398516 Epoch 180 MSE: 0.0012150387046858668 Epoch 181 MSE: 0.0012147933011874557 Epoch 182 MSE: 0.0012145593063905835 Epoch 183 MSE: 0.0012143359053879976 Epoch 184 MSE: 0.0012141227489337325 Epoch 185 MSE: 0.001213919254951179 Epoch 186 MSE: 0.0012137250741943717 Epoch 187 MSE: 0.0012135393917560577 Epoch 188 MSE: 0.0012133622076362371 Epoch 189 MSE: 0.0012131932890042663 Epoch 190 MSE: 0.0012130315881222486 Epoch 191 MSE: 0.0012128771049901843 Epoch 192 MSE: 0.0012127296067774296 Epoch 193 MSE: 0.0012125886278226972 Epoch 194 MSE: 0.001212453586049378 Epoch 195 MSE: 0.0012123244814574718 Epoch 196 MSE: 0.0012122004991397262 Epoch 197 MSE: 0.00121208222117275 Epoch 198 MSE: 0.0012119687162339687 Epoch 199 MSE: 0.0012118594022467732 Epoch 200 MSE: 0.0012117548612877727 Epoch 201 MSE: 0.0012116542784497142 Epoch 202 MSE: 0.0012115574209019542 Epoch 203 MSE: 0.0012114644050598145 Epoch 204 MSE: 0.001211374532431364 Epoch 205 MSE: 0.0012112879194319248 Epoch 206 MSE: 0.001211204333230853 Epoch 207 MSE: 0.0012111234245821834 Epoch 208 MSE: 0.0012110451934859157 Epoch 209 MSE: 0.0012109694071114063 Epoch 210 MSE: 0.0012108958326280117 Epoch 211 MSE: 0.0012108244700357318 Epoch 212 MSE: 0.0012107548536732793 Epoch 213 MSE: 0.0012106872163712978 Epoch 214 MSE: 0.001210621208883822 Epoch 215 MSE: 0.0012105568312108517 Epoch 216 MSE: 0.0012104939669370651 Epoch 217 MSE: 0.0012104324996471405 Epoch 218 MSE: 0.0012103721965104342 Epoch 219 MSE: 0.001210313057526946 Epoch 220 MSE: 
0.0012102550826966763 Epoch 221 MSE: 0.0012101982720196247 Epoch 222 MSE: 0.001210142276249826 Epoch 223 MSE: 0.0012100872118026018 Epoch 224 MSE: 0.0012100327294319868 Epoch 225 MSE: 0.0012099792947992682 Epoch 226 MSE: 0.0012099264422431588 Epoch 227 MSE: 0.0012098741717636585 Epoch 228 MSE: 0.0012098224833607674 Epoch 229 MSE: 0.0012097714934498072 Epoch 230 MSE: 0.0012097208527848125 Epoch 231 MSE: 0.0012096707941964269 Epoch 232 MSE: 0.001209620968438685 Epoch 233 MSE: 0.0012095717247575521 Epoch 234 MSE: 0.0012095228303223848 Epoch 235 MSE: 0.0012094741687178612 Epoch 236 MSE: 0.001209425856359303 Epoch 237 MSE: 0.0012093778932467103 Epoch 238 MSE: 0.0012093300465494394 Epoch 239 MSE: 0.0012092826655134559 Epoch 240 MSE: 0.0012092354008927941 Epoch 241 MSE: 0.001209188369102776 Epoch 242 MSE: 0.001209141337312758 Epoch 243 MSE: 0.0012090946547687054 Epoch 244 MSE: 0.0012090482050552964 Epoch 245 MSE: 0.0012090018717572093 Epoch 246 MSE: 0.0012089555384591222 Epoch 247 MSE: 0.0012089093215763569 Epoch 248 MSE: 0.0012088633375242352 Epoch 249 MSE: 0.0012088174698874354 Epoch 250 MSE: 0.0012087714858353138 Epoch 251 MSE: 0.001208725618198514 Epoch 252 MSE: 0.0012086800998076797 Epoch 253 MSE: 0.0012086343485862017 Epoch 254 MSE: 0.0012085885973647237 Epoch 255 MSE: 0.001208543311804533 Epoch 256 MSE: 0.0012084976769983768 Epoch 257 MSE: 0.0012084522750228643 Epoch 258 MSE: 0.00120840675663203 Epoch 259 MSE: 0.0012083611218258739 Epoch 260 MSE: 0.0012083158362656832 Epoch 261 MSE: 0.0012082704342901707 Epoch 262 MSE: 0.00120822514872998 Epoch 263 MSE: 0.0012081797467544675 Epoch 264 MSE: 0.0012081342283636332 Epoch 265 MSE: 0.0012080888263881207 Epoch 266 MSE: 0.0012080433079972863 Epoch 267 MSE: 0.0012079979060217738 Epoch 268 MSE: 0.0012079525040462613 Epoch 269 MSE: 0.0012079068692401052 Epoch 270 MSE: 0.001207861234433949 Epoch 271 MSE: 0.0012078157160431147 Epoch 272 MSE: 0.0012077700812369585 Epoch 273 MSE: 0.0012077243300154805 Epoch 274 MSE: 
0.0012076786952093244 Epoch 275 MSE: 0.0012076328275725245 Epoch 276 MSE: 0.0012075870763510466 Epoch 277 MSE: 0.001207541092298925 Epoch 278 MSE: 0.0012074949918314815 Epoch 279 MSE: 0.0012074491241946816 Epoch 280 MSE: 0.0012074029073119164 Epoch 281 MSE: 0.001207356690429151 Epoch 282 MSE: 0.001207310357131064 Epoch 283 MSE: 0.0012072641402482986 Epoch 284 MSE: 0.0012072176905348897 Epoch 285 MSE: 0.0012071712408214808 Epoch 286 MSE: 0.0012071247911080718 Epoch 287 MSE: 0.0012070779921486974 Epoch 288 MSE: 0.0012070313096046448 Epoch 289 MSE: 0.0012069842778146267 Epoch 290 MSE: 0.0012069374788552523 Epoch 291 MSE: 0.0012068904470652342 Epoch 292 MSE: 0.001206843531690538 Epoch 293 MSE: 0.0012067961506545544 Epoch 294 MSE: 0.0012067487696185708 Epoch 295 MSE: 0.0012067012721672654 Epoch 296 MSE: 0.0012066536583006382 Epoch 297 MSE: 0.001206606044434011 Epoch 298 MSE: 0.0012065580813214183 Epoch 299 MSE: 0.0012065102346241474 Epoch 300 MSE: 0.0012064622715115547 Epoch 301 MSE: 0.0012064140755683184 Epoch 302 MSE: 0.0012063659960404038 Epoch 303 MSE: 0.0012063175672665238 Epoch 304 MSE: 0.001206269022077322 Epoch 305 MSE: 0.0012062203604727983 Epoch 306 MSE: 0.0012061716988682747 Epoch 307 MSE: 0.0012061228044331074 Epoch 308 MSE: 0.0012060737935826182 Epoch 309 MSE: 0.0012060245499014854 Epoch 310 MSE: 0.0012059754226356745 Epoch 311 MSE: 0.001205925946123898 Epoch 312 MSE: 0.0012058764696121216 Epoch 313 MSE: 0.0012058268766850233 Epoch 314 MSE: 0.0012057770509272814 Epoch 315 MSE: 0.0012057271087542176 Epoch 316 MSE: 0.001205677050165832 Epoch 317 MSE: 0.0012056267587468028 Epoch 318 MSE: 0.0012055765837430954 Epoch 319 MSE: 0.0012055259430781007 Epoch 320 MSE: 0.0012054751859977841 Epoch 321 MSE: 0.0012054244289174676 Epoch 322 MSE: 0.0012053734390065074 Epoch 323 MSE: 0.0012053223326802254 Epoch 324 MSE: 0.0012052709935232997 Epoch 325 MSE: 0.0012052195379510522 Epoch 326 MSE: 0.0012051679659634829 Epoch 327 MSE: 0.0012051162775605917 Epoch 328 MSE: 
0.0012050643563270569 Epoch 329 MSE: 0.0012050123186782002 Epoch 330 MSE: 0.0012049601646140218 Epoch 331 MSE: 0.0012049076613038778 Epoch 332 MSE: 0.0012048552744090557 Epoch 333 MSE: 0.001204802538268268 Epoch 334 MSE: 0.0012047495692968369 Epoch 335 MSE: 0.0012046964839100838 Epoch 336 MSE: 0.0012046433985233307 Epoch 337 MSE: 0.0012045899638906121 Epoch 338 MSE: 0.0012045365292578936 Epoch 339 MSE: 0.0012044827453792095 Epoch 340 MSE: 0.0012044287286698818 Epoch 341 MSE: 0.001204374828375876 Epoch 342 MSE: 0.0012043205788359046 Epoch 343 MSE: 0.0012042660964652896 Epoch 344 MSE: 0.0012042114976793528 Epoch 345 MSE: 0.001204156898893416 Epoch 346 MSE: 0.0012041018344461918 Epoch 347 MSE: 0.0012040467699989676 Epoch 348 MSE: 0.001203991356305778 Epoch 349 MSE: 0.0012039359426125884 Epoch 350 MSE: 0.001203880412504077 Epoch 351 MSE: 0.0012038244167342782 Epoch 352 MSE: 0.0012037684209644794 Epoch 353 MSE: 0.0012037124251946807 Epoch 354 MSE: 0.0012036558473482728 Epoch 355 MSE: 0.0012035993859171867 Epoch 356 MSE: 0.0012035425752401352 Epoch 357 MSE: 0.0012034856481477618 Epoch 358 MSE: 0.0012034286046400666 Epoch 359 MSE: 0.001203371211886406 Epoch 360 MSE: 0.0012033137027174234 Epoch 361 MSE: 0.001203256193548441 Epoch 362 MSE: 0.001203198335133493 Epoch 363 MSE: 0.0012031401274725795 Epoch 364 MSE: 0.0012030820362269878 Epoch 365 MSE: 0.0012030235957354307 Epoch 366 MSE: 0.0012029648059979081 Epoch 367 MSE: 0.0012029061326757073 Epoch 368 MSE: 0.0012028469936922193 Epoch 369 MSE: 0.0012027877382934093 Epoch 370 MSE: 0.0012027284828945994 Epoch 371 MSE: 0.0012026689946651459 Epoch 372 MSE: 0.0012026091571897268 Epoch 373 MSE: 0.0012025493197143078 Epoch 374 MSE: 0.0012024890165776014 Epoch 375 MSE: 0.001202428713440895 Epoch 376 MSE: 0.0012023680610582232 Epoch 377 MSE: 0.0012023074086755514 Epoch 378 MSE: 0.001202246407046914 Epoch 379 MSE: 0.001202185289002955 Epoch 380 MSE: 0.0012021239381283522 Epoch 381 MSE: 0.0012020623544231057 Epoch 382 MSE: 
0.0012020006543025374 Epoch 383 MSE: 0.0012019387213513255 Epoch 384 MSE: 0.00120187655556947 Epoch 385 MSE: 0.0012018142733722925 Epoch 386 MSE: 0.0012017517583444715 Epoch 387 MSE: 0.0012016890104860067 Epoch 388 MSE: 0.0012016261462122202 Epoch 389 MSE: 0.0012015631655231118 Epoch 390 MSE: 0.0012014999520033598 Epoch 391 MSE: 0.0012014363892376423 Epoch 392 MSE: 0.001201372710056603 Epoch 393 MSE: 0.0012013086816295981 Epoch 394 MSE: 0.0012012446532025933 Epoch 395 MSE: 0.001201180275529623 Epoch 396 MSE: 0.0012011158978566527 Epoch 397 MSE: 0.0012010512873530388 Epoch 398 MSE: 0.0012009863276034594 Epoch 399 MSE: 0.00120092136785388 Epoch 400 MSE: 0.001200856058858335 Epoch 401 MSE: 0.00120079074986279 Epoch 402 MSE: 0.0012007249752059579 Epoch 403 MSE: 0.0012006593169644475 Epoch 404 MSE: 0.0012005933094769716 Epoch 405 MSE: 0.001200527185574174 Epoch 406 MSE: 0.0012004607124254107 Epoch 407 MSE: 0.0012003942392766476 Epoch 408 MSE: 0.001200327416881919 Epoch 409 MSE: 0.0012002605944871902 Epoch 410 MSE: 0.001200193422846496 Epoch 411 MSE: 0.001200126251205802 Epoch 412 MSE: 0.0012000587303191423 Epoch 413 MSE: 0.001199991093017161 Epoch 414 MSE: 0.0011999233392998576 Epoch 415 MSE: 0.0011998553527519107 Epoch 416 MSE: 0.0011997870169579983 Epoch 417 MSE: 0.0011997186811640859 Epoch 418 MSE: 0.0011996503453701735 Epoch 419 MSE: 0.0011995816603302956 Epoch 420 MSE: 0.001199512742459774 Epoch 421 MSE: 0.0011994438245892525 Epoch 422 MSE: 0.0011993745574727654 Epoch 423 MSE: 0.0011993051739409566 Epoch 424 MSE: 0.001199235673993826 Epoch 425 MSE: 0.0011991660576313734 Epoch 426 MSE: 0.001199096324853599 Epoch 427 MSE: 0.001199026475660503 Epoch 428 MSE: 0.0011989561608061194 Epoch 429 MSE: 0.0011988860787823796 Epoch 430 MSE: 0.0011988157639279962 Epoch 431 MSE: 0.0011987450998276472 Epoch 432 MSE: 0.0011986744357272983 Epoch 433 MSE: 0.0011986036552116275 Epoch 434 MSE: 0.0011985327582806349 Epoch 435 MSE: 0.0011984616285189986 Epoch 436 MSE: 
0.0011983903823420405 Epoch 437 MSE: 0.0011983191361650825 Epoch 438 MSE: 0.0011982476571574807 Epoch 439 MSE: 0.001198176178149879 Epoch 440 MSE: 0.0011981045827269554 Epoch 441 MSE: 0.0011980326380580664 Epoch 442 MSE: 0.0011979610426351428 Epoch 443 MSE: 0.0011978887487202883 Epoch 444 MSE: 0.0011978168040513992 Epoch 445 MSE: 0.0011977446265518665 Epoch 446 MSE: 0.001197672332637012 Epoch 447 MSE: 0.0011975999223068357 Epoch 448 MSE: 0.001197527744807303 Epoch 449 MSE: 0.0011974552180618048 Epoch 450 MSE: 0.0011973825749009848 Epoch 451 MSE: 0.001197309698909521 Epoch 452 MSE: 0.001197237055748701 Epoch 453 MSE: 0.001197164412587881 Epoch 454 MSE: 0.0011970914201810956 Epoch 455 MSE: 0.0011970184277743101 Epoch 456 MSE: 0.00119694578461349 Epoch 457 MSE: 0.0011968727922067046 Epoch 458 MSE: 0.0011967996833845973 Epoch 459 MSE: 0.0011967268073931336 Epoch 460 MSE: 0.0011966536985710263 Epoch 461 MSE: 0.0011965807061642408 Epoch 462 MSE: 0.0011965074809268117 Epoch 463 MSE: 0.001196434604935348 Epoch 464 MSE: 0.001196361379697919 Epoch 465 MSE: 0.0011962883872911334 Epoch 466 MSE: 0.0011962151620537043 Epoch 467 MSE: 0.0011961422860622406 Epoch 468 MSE: 0.0011960692936554551 Epoch 469 MSE: 0.0011959961848333478 Epoch 470 MSE: 0.0011959233088418841 Epoch 471 MSE: 0.0011958503164350986 Epoch 472 MSE: 0.001195777440443635 Epoch 473 MSE: 0.0011957046808674932 Epoch 474 MSE: 0.0011956319212913513 Epoch 475 MSE: 0.0011955592781305313 Epoch 476 MSE: 0.0011954867513850331 Epoch 477 MSE: 0.001195414224639535 Epoch 478 MSE: 0.0011953418143093586 Epoch 479 MSE: 0.0011952694039791822 Epoch 480 MSE: 0.0011951973428949714 Epoch 481 MSE: 0.0011951252818107605 Epoch 482 MSE: 0.0011950533371418715 Epoch 483 MSE: 0.0011949815088883042 Epoch 484 MSE: 0.0011949097970500588 Epoch 485 MSE: 0.001194838434457779 Epoch 486 MSE: 0.0011947669554501772 Epoch 487 MSE: 0.0011946957092732191 Epoch 488 MSE: 0.0011946248123422265 Epoch 489 MSE: 0.001194553915411234 Epoch 490 MSE: 
0.001194483251310885 Epoch 491 MSE: 0.0011944128200411797 Epoch 492 MSE: 0.0011943423887714744 Epoch 493 MSE: 0.0011942726559937 Epoch 494 MSE: 0.0011942025739699602 Epoch 495 MSE: 0.0011941331904381514 Epoch 496 MSE: 0.0011940636904910207 Epoch 497 MSE: 0.0011939946562051773 Epoch 498 MSE: 0.0011939258547499776 Epoch 499 MSE: 0.0011938571697100997 Training time: 19.867424726486206
# Inverse-scale the GRU's training predictions and targets back to dollars.
predict = pd.DataFrame(scaler.inverse_transform(y_train_pred.cpu().detach().numpy()))
original = pd.DataFrame(scaler.inverse_transform(y_train_gru.cpu().detach().numpy()))
sns.set_style("darkgrid")
fig = plt.figure()
fig.subplots_adjust(hspace=0.2, wspace=0.2)
# Left panel: actual vs. predicted training-set prices.
plt.subplot(1, 2, 1)
ax = sns.lineplot(x = original.index, y = original[0], label="Data", color='royalblue')
ax = sns.lineplot(x = predict.index, y = predict[0], label="Training Prediction (GRU)", color='tomato')
ax.set_title('Stock price', size = 14, fontweight='bold')
ax.set_xlabel("Days", size = 14)
ax.set_ylabel("Cost (USD)", size = 14)
# Blank out the x tick labels (the index is day offsets, not calendar dates).
ax.set_xticklabels('', size=10)
# Right panel: per-epoch training-loss curve recorded in `hist`.
plt.subplot(1, 2, 2)
ax = sns.lineplot(data=hist, color='royalblue')
ax.set_xlabel("Epoch", size = 14)
ax.set_ylabel("Loss", size = 14)
ax.set_title("Training Loss", size = 14, fontweight='bold')
fig.set_figheight(6)
fig.set_figwidth(16)
# Move model and tensors back to CPU so outputs can be converted to numpy.
device = torch.device("cpu")
model.to(device)
x_train = x_train.to(device)
y_train_pred = y_train_pred.to(device)
y_train_gru = y_train_gru.to(device)
# make predictions on the held-out test set
# NOTE(review): x_test is assumed to already be on CPU (or the same device
# as the model) — confirm against the cell that built it.
y_test_pred = model(x_test).cpu()
# invert predictions — undo the MinMax scaling so errors are in dollars
y_train_pred = scaler.inverse_transform(y_train_pred.detach().numpy())
y_train = scaler.inverse_transform(y_train_gru.detach().numpy())
y_test_pred = scaler.inverse_transform(y_test_pred.detach().numpy())
y_test = scaler.inverse_transform(y_test_gru.detach().numpy())
# calculate root mean squared error
trainScore = math.sqrt(mean_squared_error(y_train[:,0], y_train_pred[:,0]))
print('Train Score: %.2f RMSE' % (trainScore))
testScore = math.sqrt(mean_squared_error(y_test[:,0], y_test_pred[:,0]))
print('Test Score: %.2f RMSE' % (testScore))
# Store GRU scores and training time for the model comparison.
gru.append(trainScore)
gru.append(testScore)
gru.append(training_time)
Train Score: 0.51 RMSE Test Score: 0.94 RMSE
# Align the GRU predictions with the full daily price series for plotting.
# Positions outside each prediction's window stay NaN so they are not drawn.
trainPredictPlot = np.full(price.shape, np.nan)
trainPredictPlot[lookback:len(y_train_pred)+lookback, :] = y_train_pred

# Test predictions start where the training window ends.
testPredictPlot = np.full(price.shape, np.nan)
testPredictPlot[len(y_train_pred)+lookback-1:len(price)-1, :] = y_test_pred

# Undo the MinMax scaling so the actual series is in dollars again.
original = scaler.inverse_transform(price['Adj Close'].values.reshape(-1,1))

# Columns: 0 = train prediction, 1 = test prediction, 2 = actual price.
predictions = np.concatenate([trainPredictPlot, testPredictPlot, original], axis=1)
result = pd.DataFrame(predictions)
# Interactive plot of GRU train/test predictions against the actual price.
# Fix: the original wrapped each trace in a redundant nested
# go.Scatter(go.Scatter(...)); a single constructor per trace is correct.
fig = go.Figure()
fig.add_trace(go.Scatter(x=df["Date"], y=result[0],
                         mode='lines',
                         name='Train prediction'))
fig.add_trace(go.Scatter(x=df["Date"], y=result[1],
                         mode='lines',
                         name='Test prediction'))
fig.add_trace(go.Scatter(x=df["Date"], y=result[2],
                         mode='lines',
                         name='Actual Value'))
# Range slider plus quick-zoom buttons (1 month .. all data).
fig.update_xaxes(
    rangeslider_visible=True,
    rangeselector=dict(
        buttons=list([
            dict(count=1, label="1m", step="month", stepmode="backward"),
            dict(count=6, label="6m", step="month", stepmode="backward"),
            dict(count=1, label="YTD", step="year", stepmode="todate"),
            dict(count=1, label="1y", step="year", stepmode="backward"),
            dict(count=1, label="2y", step="year", stepmode="backward"),
            dict(step="all")
        ]),
        font=dict(family='Rockwell', color='black'),
    )
)
fig.show()